/******************************************************************************
This file is part of AppKit.
Project: appkit
Author : FergusZeng
Email  : cblock@126.com
git	   : https://gitee.com/newgolo/appkit.git
*******************************************************************************
MIT License

Copyright (c) 2022 cblock@126.com

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
******************************************************************************/
#include "appkit/network.h"

#include <arpa/inet.h>
#include <net/if.h>
#include <net/if_arp.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include "appkit/fileio.h"
#include "appkit/strutil.h"
#include "appkit/system.h"
#include "appkit/tracer.h"

namespace appkit {
// Constructor and destructor perform no work in this file; the compiler
// generated versions are used.
NetUtil::NetUtil() = default;

NetUtil::~NetUtil() = default;

/**
 * @brief Check whether @p ip begins with a valid IPv4 address.
 *
 * Only the first ':'-separated field is examined, so a trailing ":port"
 * suffix (e.g. "10.0.0.1:8080") does not affect the result.
 *
 * @param ip candidate address, optionally followed by ":<anything>".
 * @return true when the first field parses with inet_pton(AF_INET, ...).
 */
bool NetUtil::isValidIP(const std::string& ip) {
    auto fields = StrUtil::splitString(ip, ":");
    if (fields.empty()) {
        return false;
    }
    struct in_addr parsed;
    return inet_pton(AF_INET, CSTR(fields[0]), &parsed) > 0;
}

/**
 * @brief Ensure *subnetIP lies in the subnet defined by @p gatewayIP and
 *        @p subnetMask, rewriting it if it does not.
 *
 * If *subnetIP already has the same network part as gatewayIP it is left
 * untouched. Otherwise its network bits are replaced with the gateway's
 * network bits (its host bits are kept) and the result is stored back into
 * *subnetIP.
 *
 * @param gatewayIP  IPv4 address of the gateway (dotted decimal).
 * @param subnetMask IPv4 netmask (dotted decimal).
 * @param subnetIP   in/out: address to check and possibly rewrite.
 * @return false if @p subnetIP is null or any address fails to parse or
 *         format; true otherwise.
 */
bool NetUtil::checkSubnetIP(const std::string& gatewayIP,
                            const std::string& subnetMask,
                            std::string* subnetIP) {
    if (subnetIP == nullptr) {
        return false;
    }

    struct in_addr ip, mask, sub, result;
    if (inet_pton(AF_INET, CSTR(gatewayIP), &ip) <= 0) {
        return false;
    }
    if (inet_pton(AF_INET, CSTR(subnetMask), &mask) <= 0) {
        return false;
    }
    if (inet_pton(AF_INET, CSTR(*subnetIP), &sub) <= 0) {
        return false;
    }

    // AND/OR/NOT act bit-by-bit, so they are valid directly on the
    // network-byte-order s_addr values without any ntohl/htonl conversion.
    if ((sub.s_addr & mask.s_addr) == (ip.s_addr & mask.s_addr)) {
        return true;
    }

    result.s_addr = (ip.s_addr & mask.s_addr) | ((~mask.s_addr) & sub.s_addr);
    char text[INET_ADDRSTRLEN] = {0};
    if (!inet_ntop(AF_INET, &result, text, sizeof(text))) {
        return false;
    }
    *subnetIP = std::string(text);
    return true;
}

/**
 * @brief Build the address with host number @p subIndex inside the subnet of
 *        @p gatewayIP / @p subnetMask.
 *
 * Example: gateway "192.168.1.1", mask "255.255.255.0", subIndex 20
 * yields "192.168.1.20". Only the bits of @p subIndex that fall in the host
 * part of the mask are used.
 *
 * @param gatewayIP  IPv4 address of the gateway (dotted decimal).
 * @param subnetMask IPv4 netmask (dotted decimal).
 * @param subIndex   host number within the subnet.
 * @return the resulting dotted-decimal address, or "" if either input
 *         address fails to parse.
 */
std::string NetUtil::getSubnetIP(const std::string& gatewayIP,
                                 const std::string& subnetMask, int subIndex) {
    struct in_addr ip, mask, result;
    if (inet_pton(AF_INET, CSTR(gatewayIP), &ip) <= 0) {
        return "";
    }
    if (inet_pton(AF_INET, CSTR(subnetMask), &mask) <= 0) {
        return "";
    }

    // Combine network and host parts in host byte order so the index lands
    // in the low-order (last) octets regardless of CPU endianness, and OR
    // them together (AND would wipe out the network part).
    const uint32_t hostBits = ~ntohl(mask.s_addr);
    const uint32_t network = ntohl(ip.s_addr) & ~hostBits;
    const uint32_t host = static_cast<uint32_t>(subIndex) & hostBits;
    result.s_addr = htonl(network | host);

    char tmp[INET_ADDRSTRLEN] = {0};
    if (!inet_ntop(AF_INET, &result, tmp, sizeof(tmp))) {
        return "";
    }
    return std::string(tmp);
}

/**
 * @brief List the interface names returned by ioctl(SIOCGIFCONF).
 *
 * The request buffer starts at 16 entries and is doubled (up to 1024
 * entries) while the kernel fills it completely, so hosts with many
 * interfaces are not silently truncated.
 *
 * NOTE(review): per netdevice(7), on Linux SIOCGIFCONF only reports
 * interfaces that have an IPv4 address — confirm this is the intended set.
 *
 * @return interface names in the order reported; empty on error.
 */
std::vector<std::string> NetUtil::getInterfaces() {
    std::vector<std::string> intfList;
    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        return intfList;
    }

    const size_t kMaxEntries = 1024;
    std::vector<struct ifreq> buf(16);
    struct ifconf ifc;
    for (;;) {
        memset(&ifc, 0, sizeof(ifc));
        const size_t capacity = buf.size() * sizeof(struct ifreq);
        ifc.ifc_len = static_cast<int>(capacity);
        ifc.ifc_buf = reinterpret_cast<char*>(buf.data());
        if (ioctl(sockfd, SIOCGIFCONF, &ifc) < 0) {
            TRACE_ERR("NetUtil::getInterfaces, SIOCGIFCONF error: %s!", ERRSTR);
            close(sockfd);  // previously leaked on this path
            return intfList;
        }
        // A completely filled buffer may mean the list was truncated; retry
        // with a larger one unless the cap has been reached.
        if (static_cast<size_t>(ifc.ifc_len) < capacity ||
            buf.size() >= kMaxEntries) {
            break;
        }
        buf.resize(buf.size() * 2);
    }
    close(sockfd);

    const size_t count = static_cast<size_t>(ifc.ifc_len) / sizeof(struct ifreq);
    intfList.reserve(count);
    for (size_t i = 0; i < count; ++i) {
        // Bound the read in case a name occupies all IFNAMSIZ bytes.
        intfList.emplace_back(buf[i].ifr_name,
                              strnlen(buf[i].ifr_name, IFNAMSIZ));
    }
    return intfList;
}

/**
 * @brief Get the IPv4 address of interface @p intf as a dotted-decimal string.
 *
 * First tries ioctl(SIOCGIFADDR). If that yields an empty result, falls back
 * to the shell pipeline
 *   ifconfig <intf> | grep inet | grep -v inet6 | awk '{print $2}'
 * and returns its output with trailing blanks trimmed.
 *
 * NOTE(review): the fallback depends on the `ifconfig` output format; older
 * net-tools print "inet addr:x.x.x.x", and interfaces with several IPv4
 * addresses produce several lines — confirm this is acceptable.
 *
 * @param intf interface name; must be shorter than IFNAMSIZ.
 * @return the address, or "" on failure.
 */
std::string NetUtil::getInetAddr(const std::string& intf) {
    if (intf.size() >= IFNAMSIZ) {
        return "";
    }
    auto viaIoctl = [&intf]() -> std::string {
        struct ifreq ifr;
        // Zero the request so ifr_name is always NUL-terminated.
        memset(&ifr, 0, sizeof(ifr));
        strncpy(ifr.ifr_name, intf.c_str(), IFNAMSIZ - 1);
        int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
        if (sockfd < 0) {
            return "";
        }

        if (ioctl(sockfd, SIOCGIFADDR, &ifr) < 0) {
            TRACE_ERR("NetUtil::getInetAddr(%s), SIOCGIFADDR error: %s!",
                      intf.data(), ERRSTR);
            close(sockfd);
            return "";
        }
        struct sockaddr_in* addr = (struct sockaddr_in*)&(ifr.ifr_addr);
        char* ipstr = inet_ntoa(addr->sin_addr);
        if (!ipstr) {
            TRACE_ERR("NetUtil::getInetAddr(%s), inet_ntoa error!",
                      intf.data());
            close(sockfd);
            return "";
        }
        close(sockfd);
        return std::string(ipstr);
    };
    auto viaShell = [&intf]() -> std::string {
        // Single-quote the interface name so shell metacharacters in it are
        // not interpreted; an embedded ' becomes '\'' (close, literal, reopen).
        std::string quoted = "'";
        for (char c : intf) {
            if (c == '\'') {
                quoted += "'\\''";
            } else {
                quoted += c;
            }
        }
        quoted += "'";
        auto cmd = StrUtil::format(
            "ifconfig %s | grep inet | grep -v inet6 | awk '{print $2}'",
            quoted.c_str());
        std::string result;
        if (RC_OK != System::execute(cmd, &result)) {
            return "";
        }
        return StrUtil::trimTailBlank(result);
    };
    auto ipAddr = viaIoctl();
    if (ipAddr.empty()) {
        return viaShell();
    }
    return ipAddr;
}

/**
 * @brief Get the IPv4 netmask of interface @p intf via ioctl(SIOCGIFNETMASK).
 *
 * @param intf interface name; must be shorter than IFNAMSIZ.
 * @return the netmask in dotted-decimal form, or "" on failure.
 */
std::string NetUtil::getMaskAddr(const std::string& intf) {
    if (intf.size() >= IFNAMSIZ) {
        return "";
    }

    struct ifreq ifr;
    // Zero the request so ifr_name is always NUL-terminated; copying only
    // intf.size() bytes into an uninitialized buffer left garbage after it.
    memset(&ifr, 0, sizeof(ifr));
    strncpy(ifr.ifr_name, intf.c_str(), IFNAMSIZ - 1);
    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        return "";
    }

    if (ioctl(sockfd, SIOCGIFNETMASK, &ifr) < 0) {
        TRACE_ERR("NetUtil::getMaskAddr(%s),SIOCGIFNETMASK error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return "";
    }
    // The mask is read back from the ifr_addr slot of the request.
    struct sockaddr_in* addr = (struct sockaddr_in*)&(ifr.ifr_addr);
    char* ipstr = inet_ntoa(addr->sin_addr);
    if (!ipstr) {
        TRACE_ERR("NetUtil::getMaskAddr(%s),inet_ntoa error!", intf.data());
        close(sockfd);
        return "";
    }
    std::string netmask = ipstr;
    close(sockfd);
    return netmask;
}

std::vector<uint8> NetUtil::getMacAddr(const std::string& intf) {
    std::vector<uint8> macAddr;
    if (intf.size() >= IFNAMSIZ) {
        return macAddr;
    }

    auto f1 = [intf]() {
        std::vector<uint8> macVect;
        struct ifreq ifr;
        strncpy(ifr.ifr_name, intf.data(), intf.size());
        int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
        if (sockfd < 0) {
            return macVect;
        }
        if (ioctl(sockfd, SIOCGIFHWADDR, &ifr) < 0) {
            TRACE_ERR("NetUtil::getMacAddr(%s), SIOCGIFHWADDR error: %s!",
                      intf.data(), ERRSTR);
            close(sockfd);
            return macVect;
        }
        close(sockfd);
        for (auto i = 0; i < 6; i++) {
            macVect.push_back(ifr.ifr_hwaddr.sa_data[i]);
        }
        return macVect;
    };

    auto f2 = [intf]() {
        std::vector<uint8> macVect;
        auto macStr = File::readAll(
            StrUtil::format("/sys/class/net/%s/address", intf.data()));
        if (macStr.empty()) {
            return macVect;
        }
        macStr = StrUtil::trimTailBlank(macStr);  // 去掉首尾非可见字符
        auto vect = StrUtil::splitString(macStr, ":", false);
        for (auto& hex : vect) {
            macVect.push_back(StrUtil::htoi(hex));
        }
        return macVect;
    };

    macAddr = f1();
    if (macAddr.empty()) {
        return f2();
    }
    return macAddr;
}

/**
 * @brief Report whether interface @p intf is administratively up.
 *
 * @param intf interface name; must be shorter than IFNAMSIZ.
 * @return NetHwState::Up if IFF_UP is set in the SIOCGIFFLAGS flags,
 *         NetHwState::Down if it is clear, NetHwState::Error if the name is
 *         too long, the socket cannot be opened, or the ioctl fails.
 */
NetHwState NetUtil::getHwState(const std::string& intf) {
    if (intf.size() >= IFNAMSIZ) {
        return NetHwState::Error;
    }
    struct ifreq ifr;
    // Zero the request so ifr_name is always NUL-terminated.
    memset(&ifr, 0, sizeof(ifr));
    strncpy(ifr.ifr_name, intf.c_str(), IFNAMSIZ - 1);

    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        TRACE_ERR("NetUtil::getHwState(%s), open socket error!", intf.data());
        return NetHwState::Error;
    }
    int rc = ioctl(sockfd, SIOCGIFFLAGS, &ifr);
    if (rc < 0) {
        TRACE_ERR("NetUtil::getHwState(%s), SIOCGIFFLAGS error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return NetHwState::Error;
    }

    close(sockfd);
    if (ifr.ifr_flags & IFF_UP) {
        return NetHwState::Up;
    } else {
        return NetHwState::Down;
    }
}

/**
 * @brief Read the link state of @p intf from /sys/class/net/<intf>/operstate.
 *
 * @param intf interface name; must be shorter than IFNAMSIZ.
 * @return NetLinkState::Error if the name is too long;
 *         NetLinkState::Connected if the operstate file exists and its
 *         contents contain "up"; NetLinkState::Disconnected otherwise
 *         (including when the file is missing).
 */
NetLinkState NetUtil::getLinkState(const std::string& intf) {
    if (intf.size() >= IFNAMSIZ) {
        return NetLinkState::Error;
    }
    auto operstatePath =
        StrUtil::format("/sys/class/net/%s/operstate", intf.data());
    if (!File::exists(operstatePath)) {
        return NetLinkState::Disconnected;
    }
    auto content = File::readAll(operstatePath);
    const bool linkUp = content.find("up") != std::string::npos;
    return linkUp ? NetLinkState::Connected : NetLinkState::Disconnected;
}

/**
 * @brief Assign IPv4 address @p ipaddr to interface @p intf via
 *        ioctl(SIOCSIFADDR).
 *
 * The request's ifr_addr is built directly from @p ipaddr rather than first
 * reading the current address, so interfaces with no address yet can also
 * be configured (a preliminary SIOCGIFADDR fails on those).
 *
 * @param intf   interface name; must be shorter than IFNAMSIZ.
 * @param ipaddr address in a form accepted by inet_aton.
 * @return true on success; false if an argument is invalid, the socket
 *         cannot be opened, or the ioctl fails.
 */
bool NetUtil::setInetAddr(const std::string& intf, const std::string& ipaddr) {
    if (intf.size() >= IFNAMSIZ) {
        return false;
    }
    struct ifreq ifr;
    // Zero the request so ifr_name is NUL-terminated and the sockaddr
    // padding is clean.
    memset(&ifr, 0, sizeof(ifr));
    strncpy(ifr.ifr_name, intf.c_str(), IFNAMSIZ - 1);

    struct sockaddr_in* addr = (struct sockaddr_in*)&(ifr.ifr_addr);
    addr->sin_family = AF_INET;
    if (inet_aton(CSTR(ipaddr), &(addr->sin_addr)) == 0) {
        TRACE_ERR("NetUtil::setInetAddr(%s), inet_aton error!", intf.data());
        return false;
    }

    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        return false;
    }
    if (ioctl(sockfd, SIOCSIFADDR, &ifr) < 0) {
        TRACE_ERR("NetUtil::setInetAddr(%s), SIOCSIFADDR error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return false;
    }
    close(sockfd);
    return true;
}

bool NetUtil::setMaskAddr(const std::string& intf, const std::string& netmask) {
    if (intf.size() >= IFNAMSIZ) {
        return false;
    }

    struct ifreq ifr;
    strncpy(ifr.ifr_name, intf.data(), intf.size());

    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        return false;
    }

    int rc = ioctl(sockfd, SIOCGIFADDR, &ifr);
    if (rc < 0) {
        TRACE_ERR("NetUtil::setMaskAddr(%s), SIOCGIFADDR error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return false;
    }
    struct sockaddr_in* addr = (struct sockaddr_in*)&(ifr.ifr_addr);
    addr->sin_family = AF_INET;
    rc = inet_aton(CSTR(netmask), &(addr->sin_addr));
    if (rc == 0) {
        TRACE_ERR("NetUtil::setMaskAddr(%s), inet_aton error!", intf.data());
        close(sockfd);
        return false;
    }

    rc = ioctl(sockfd, SIOCSIFNETMASK, &ifr);
    if (rc < 0) {
        close(sockfd);
        return false;
    }
    close(sockfd);
    return true;
}

/**
 * @brief Set the hardware (MAC) address of interface @p intf via
 *        ioctl(SIOCSIFHWADDR), with address family ARPHRD_ETHER.
 *
 * @param intf interface name; must be shorter than IFNAMSIZ.
 * @param mac  17-character string of six two-digit hex bytes separated by
 *             ':' or '-' (e.g. "00:1a:2B:3c:4d:5e"); hex digits may be
 *             upper or lower case.
 * @return true on success; false if an argument is malformed, the socket
 *         cannot be opened, or the ioctl fails.
 */
bool NetUtil::setMacAddr(const std::string& intf, const std::string& mac) {
    if (intf.size() >= IFNAMSIZ || mac.size() != 17) {
        return false;
    }

    // Value of one hex digit, or -1 if the character is not a hex digit.
    auto hexValue = [](char c) -> int {
        if (c >= '0' && c <= '9') return c - '0';
        if (c >= 'a' && c <= 'f') return c - 'a' + 10;
        if (c >= 'A' && c <= 'F') return c - 'A' + 10;
        return -1;
    };

    // Parse and validate before touching the kernel, so malformed input is
    // rejected instead of being silently converted into a bogus address.
    unsigned char macaddr[6] = {0};
    for (int i = 0; i < 6; i++) {
        const int hi = hexValue(mac[3 * i]);
        const int lo = hexValue(mac[3 * i + 1]);
        const bool badSeparator =
            (i < 5) && mac[3 * i + 2] != ':' && mac[3 * i + 2] != '-';
        if (hi < 0 || lo < 0 || badSeparator) {
            TRACE_ERR("NetUtil::setMacAddr(%s), invalid mac: %s!", intf.data(),
                      mac.data());
            return false;
        }
        macaddr[i] = static_cast<unsigned char>(hi * 16 + lo);
    }

    struct ifreq ifr;
    // Zero the request so ifr_name is always NUL-terminated.
    memset(&ifr, 0, sizeof(ifr));
    strncpy(ifr.ifr_name, intf.c_str(), IFNAMSIZ - 1);
    ifr.ifr_hwaddr.sa_family = ARPHRD_ETHER;
    memcpy(ifr.ifr_hwaddr.sa_data, macaddr, sizeof(macaddr));

    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        return false;
    }
    if (ioctl(sockfd, SIOCSIFHWADDR, &ifr) < 0) {
        TRACE_ERR("NetUtil::setMacAddr(%s), SIOCSIFHWADDR error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return false;
    }
    close(sockfd);
    return true;
}

/**
 * @brief Bring interface @p intf administratively up or down.
 *
 * Reads the current flags with SIOCGIFFLAGS, clears IFF_UP when @p state is
 * NetHwState::Down (sets it for any other state), and writes the flags back
 * with SIOCSIFFLAGS.
 *
 * @param intf  interface name; must be shorter than IFNAMSIZ.
 * @param state desired state.
 * @return true on success; false if the name is too long, the socket cannot
 *         be opened, or either ioctl fails.
 */
bool NetUtil::setHwState(const std::string& intf, NetHwState state) {
    if (intf.size() >= IFNAMSIZ) {
        return false;
    }
    struct ifreq ifr;
    // Zero the request so ifr_name is always NUL-terminated.
    memset(&ifr, 0, sizeof(ifr));
    strncpy(ifr.ifr_name, intf.c_str(), IFNAMSIZ - 1);

    int sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if (sockfd < 0) {
        return false;
    }
    /* Read the current interface flags first, so other flags are preserved. */
    int rc = ioctl(sockfd, SIOCGIFFLAGS, &ifr);
    if (rc < 0) {
        TRACE_ERR("NetUtil::setHwState(%s), SIOCGIFFLAGS error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return false;
    }

    /* Then toggle only IFF_UP and write the flags back. */
    if (state == NetHwState::Down) {
        ifr.ifr_flags &= (~IFF_UP);
    } else {
        ifr.ifr_flags |= IFF_UP;
    }
    rc = ioctl(sockfd, SIOCSIFFLAGS, &ifr);
    if (rc < 0) {
        TRACE_ERR("NetUtil::setHwState(%s), SIOCSIFFLAGS error: %s!",
                  intf.data(), ERRSTR);
        close(sockfd);
        return false;
    }

    close(sockfd);
    return true;
}

/**
 * @brief Resolve @p hostName ("host" or "host:port") to an IPv4 address.
 *
 * The host part is resolved with getaddrinfo restricted to AF_INET; the
 * first result is used. If a second ':'-separated field is present it is
 * appended verbatim as ":<field>".
 *
 * @param hostName host name or IPv4 literal, optionally followed by ":port".
 * @return "a.b.c.d" or "a.b.c.d:port", or "" if the input is empty or the
 *         host cannot be resolved.
 */
std::string NetUtil::getHostIP(const std::string& hostName) {
    auto splits = StrUtil::splitString(hostName, ":");
    if (splits.empty()) {
        // Guard: indexing splits[0] below would be undefined behavior.
        return "";
    }
    std::string port = "";
    if (splits.size() >= 2) {
        port = std::string(":") + splits[1];
    }

    // getaddrinfo + inet_ntop are reentrant, unlike gethostbyname, which
    // returns a pointer into shared static storage.
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;  // one entry per address, not per socktype
    struct addrinfo* res = nullptr;
    if (getaddrinfo(CSTR(splits[0]), nullptr, &hints, &res) != 0 ||
        res == nullptr) {
        return "";
    }

    char text[INET_ADDRSTRLEN] = {0};
    const struct sockaddr_in* sin =
        reinterpret_cast<const struct sockaddr_in*>(res->ai_addr);
    const bool formatted =
        inet_ntop(AF_INET, &sin->sin_addr, text, sizeof(text)) != nullptr;
    freeaddrinfo(res);
    if (!formatted) {
        return "";
    }
    return std::string(text).append(port);
}
}  // namespace appkit
